#!/usr/bin/env python
# -*- coding: utf-8 -*-
import random
import numpy as np
import copy
import torch
from torchvision import datasets, transforms

from common import *

def init_deterministic(seed=1234):
    """Seed every RNG used by the experiments and force deterministic cuDNN.

    Call this at the start of each ``run_experiments`` call so runs are
    reproducible.

    Args:
        seed: value passed to ``torch.manual_seed``, ``np.random.seed`` and
            ``random.seed`` (default 1234).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    # Prefer deterministic cuDNN kernels over autotuned (possibly
    # nondeterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)

def sign(grad):
    """Return a list with the element-wise sign of each tensor in ``grad``."""
    return list(map(torch.sign, grad))

def sto(t,args):
    """Stochastically binarize ``t`` to a tensor of +1/-1 values.

    Each element is clipped to ``[-args.b, args.b]`` and mapped to the
    probability ``0.5 * (1 + clipped / (args.b * (1 + args.beta)))`` of
    drawing +1 (otherwise -1), sampled with ``torch.bernoulli``.
    """
    bound = args.b
    clipped = torch.clip(t, -bound, bound)
    prob_plus = 0.5 * (1 + clipped / (bound * (1 + args.beta)))
    # Map a {0, 1} Bernoulli draw to {-1, +1}.
    return 2 * torch.bernoulli(prob_plus) - 1

def flatten(grad_update):
    """Concatenate a sequence of tensors into one 1-D tensor, in order."""
    pieces = [tensor.data.view(-1) for tensor in grad_update]
    return torch.cat(pieces)

def unflatten(flattened, normal_shape):
    """Split a flat vector back into tensors shaped like ``normal_shape``.

    ``normal_shape`` is a sequence of tensors used only as shape templates;
    consecutive chunks of ``flattened`` are reshaped to match each one.
    """
    pieces = []
    offset = 0
    for template in normal_shape:
        count = len(template.view(-1))
        chunk = flattened[offset:offset + count]
        pieces.append(torch.as_tensor(chunk).reshape(template.size()))
        offset += count
    return pieces

def stosign(args,grad):
    """Apply ``sto`` (stochastic +1/-1 binarization) to every tensor in ``grad``."""
    return [sto(tensor, args) for tensor in grad]

def get_dataset(args):
    """Return ``(train_dataset, test_dataset, user_groups)`` for ``args.dataset``.

    ``user_groups`` is the dict produced by ``assign_data``: it maps each user
    index to the array of training-sample indices assigned to that user.
    When ``args.iid`` is truthy the partition concentration is infinite
    (intended as an IID split); otherwise ``args.alpha`` is used.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'cifar'`` nor ``'mnist'``.
    """
    if args.dataset == 'cifar':
        data_dir = 'data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_cls = datasets.CIFAR10
    elif args.dataset == 'mnist':
        data_dir = 'data/mnist/'
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        dataset_cls = datasets.MNIST
    else:
        # Previously fell through to an UnboundLocalError at the return.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected 'cifar' or 'mnist'")

    train_dataset = dataset_cls(data_dir, train=True, download=True,
                                transform=apply_transform)
    test_dataset = dataset_cls(data_dir, train=False, download=True,
                               transform=apply_transform)

    # Infinite concentration is used for the IID case; finite alpha skews
    # the label distribution per user.
    concentration = float('infinity') if args.iid else args.alpha
    user_groups = assign_data(train_dataset, args.num_users, concentration)

    return train_dataset, test_dataset, user_groups

def assign_data(dataset,pool_size=1,alpha=1,**kwargs):
    """Partition the sample indices of ``dataset`` among ``pool_size`` users.

    The split is delegated to ``create_lda_partitions`` with
    ``concentration=alpha`` and ``accept_imbalanced=True``.

    Args:
        dataset: object exposing a ``targets`` sequence of labels.
        pool_size: number of partitions (users).
        alpha: concentration passed to ``create_lda_partitions``.
        **kwargs: accepted for call compatibility; unused.

    Returns:
        dict mapping each user index ``0 .. pool_size - 1`` to the array of
        sample indices assigned to that user.
    """
    labels = np.array(dataset.targets)
    indices = np.arange(len(labels))
    partitions, _ = create_lda_partitions(
        [indices, labels], num_partitions=pool_size, concentration=alpha,
        accept_imbalanced=True
    )
    # Each partition is an (indices, labels) pair; only the indices are kept.
    return {p: partitions[p][0] for p in range(pool_size)}

def momentum(model, velocity, grad, lr, gamma=0.9):
    """Apply one heavy-ball momentum step to ``model`` in place.

    For each parameter ``p`` with velocity ``v`` and gradient ``g``::

        v <- gamma * v + lr * g
        p <- p - v

    Args:
        model: module whose ``parameters()`` are updated in place.
        velocity: list of per-parameter velocity tensors in
            ``model.parameters()`` order; its entries are replaced in place.
        grad: per-parameter gradients in the same order.
        lr: learning rate.
        gamma: momentum coefficient (default 0.9).

    Returns:
        ``(model, velocity)`` -- the same objects that were passed in.
    """
    for layer_no, (param_model, param_update) in enumerate(zip(model.parameters(), grad)):
        velocity[layer_no] = gamma * velocity[layer_no] + lr * param_update.data
        param_model.data -= velocity[layer_no]
    return model, velocity

def average_weights(w,Byzantine):
    """Return the element-wise average of a list of state dicts.

    Every entry of ``w`` is added to the running sum except those whose index
    is in ``Byzantine``, which are subtracted instead; the sum is then divided
    by ``len(w)``. Index 0 follows the same rule as every other index.

    Args:
        w: non-empty list of dicts that share the same keys, mapping to tensors.
        Byzantine: container of indices into ``w`` whose entries are negated.

    Returns:
        A deep copy of ``w[0]`` whose values are replaced by the averages;
        the input dicts are not modified.
    """
    negate_first = 0 in Byzantine
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        if negate_first:
            # Rebind rather than negate in place; the deepcopy already
            # isolates w[0], this just keeps the intent obvious.
            w_avg[key] = -w_avg[key]
        for i in range(1, len(w)):
            if i in Byzantine:
                w_avg[key] -= w[i][key]
            else:
                w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg

def compute_grad_update(args, old_model, new_model, lr, device=None):
    """Recover the per-parameter gradient implied by an SGD step.

    Returns ``(new - old) / (-lr)`` for each pair of corresponding parameters
    of ``old_model`` and ``new_model``. When ``device`` is given, both models
    are moved there first. ``args`` is unused.
    """
    # TODO: maybe later restrict this to selected layers/parameters.
    if device:
        old_model = old_model.to(device)
        new_model = new_model.to(device)
    param_pairs = zip(old_model.parameters(), new_model.parameters())
    return [(after.data - before.data) / (-lr) for before, after in param_pairs]

def add_gradient_updates(grad_update_1, grad_update_2, weight = 1.0):
    """Accumulate ``weight * grad_update_2`` into ``grad_update_1`` in place.

    Each tensor of ``grad_update_2`` is moved to the device of its
    counterpart in ``grad_update_1`` before adding. This replaces a
    hard-coded ``'cuda:0'``, which broke on CPU-only runs and on accumulators
    living on any other device.

    Raises:
        AssertionError: if the two updates have different lengths.
    """
    assert len(grad_update_1) == len(grad_update_2), "Lengths of the two grad_updates not equal"

    for acc, delta in zip(grad_update_1, grad_update_2):
        acc.data += delta.to(acc.device).data * weight

def add_update_to_model(model, update, wd=0,weight=1.0, device=None):
    """Add ``weight * (update + wd * param)`` to each model parameter in place.

    Returns ``model`` unchanged when ``update`` is empty. When ``device`` is
    given, the model and every update tensor are moved there first.
    """
    if not update:
        return model
    if device:
        model = model.to(device)
        update = [tensor.to(device) for tensor in update]

    for param, delta in zip(model.parameters(), update):
        # ``wd`` folds an L2 weight-decay term into the step.
        step = delta.data + wd * param.data
        param.data += weight * step
    return model

def aggregate_signsgd(epoch, args, global_model, grad_updates, lr, device=None, residual_error=None, test_dataset=None, user_groups=None):
    """Aggregate client updates per ``args.mode`` and apply them to ``global_model``.

    Supported modes:
        * ``'FedAvg'``: sum of the raw client updates, each scaled by
          ``len(user_groups[i]) / total_records`` when ``args.weighted`` is
          truthy, otherwise by ``1 / len(grad_updates)``.
        * ``'SIGNSGD'``: element-wise sign of the sum of per-client signs
          (a majority vote).
        * ``'STO-SIGNSGD'``: as ``'SIGNSGD'``, but each client update is first
          binarized with ``stosign``.

    The aggregate is applied in place with ``add_update_to_model`` using
    step ``-lr`` and weight decay ``args.weight_decay``.

    Args:
        epoch, residual_error, test_dataset: unused; kept for call compatibility.
        args: config namespace; reads ``mode`` and ``weight_decay``, plus
            ``weighted`` (FedAvg) and ``num_users`` (weighted FedAvg).
        global_model: model updated in place.
        grad_updates: list of client updates, each a list of tensors.
        lr: learning rate.
        device: optional device the updates are moved to before aggregation.
        user_groups: per-user sample-index collections; only needed for
            weighted FedAvg.

    Returns:
        ``global_model``. When ``grad_updates`` is empty it is returned
        unchanged (instead of ``None``).

    Raises:
        AssertionError: if client updates contain different numbers of tensors.
        ValueError: if ``args.mode`` is not a supported mode.
    """
    if not grad_updates:
        return global_model
    len_first = len(grad_updates[0])
    assert all(len(update) == len_first for update in grad_updates), "Different shapes of parameters. Cannot aggregate."

    if device:
        # The updates are only read below, so moving them (no deep copy) is
        # enough; .to() is a no-op for tensors already on ``device``.
        grad_updates = [[param.to(device) for param in update] for update in grad_updates]

    # Zero accumulator shaped like one client update, on that update's device.
    server_update = [torch.zeros(grad.shape, device=grad.device) for grad in grad_updates[0]]

    if args.mode == 'FedAvg':
        if args.weighted:
            # NOTE(review): the denominator counts the records of all
            # args.num_users users, and grad_updates[i] is assumed to belong to
            # user i. If only a subset of users participates, the weights will
            # not sum to 1 -- confirm against the caller.
            all_records = sum(len(user_groups[u]) for u in range(args.num_users))
        for i, update in enumerate(grad_updates):
            if args.weighted:
                weight = len(user_groups[i]) / all_records
            else:
                weight = 1 / len(grad_updates)
            add_gradient_updates(server_update, update, weight=weight)
        aggregated = server_update
    elif args.mode == 'SIGNSGD':
        for update in grad_updates:
            add_gradient_updates(server_update, sign(update))
        aggregated = sign(server_update)
    elif args.mode == 'STO-SIGNSGD':
        for update in grad_updates:
            add_gradient_updates(server_update, stosign(args, update))
        aggregated = sign(server_update)
    else:
        raise ValueError(f"Unknown aggregation mode: {args.mode!r}")

    add_update_to_model(global_model, aggregated, wd=args.weight_decay, weight=-1.0 * lr)
    return global_model

def exp_details(args):
    """Print a human-readable summary of the experiment configuration."""
    distribution = 'IID' if args.iid else 'Non-IID'
    report = (
        '\nExperimental details:',
        f'    Model     : {args.model}',
        f'    Optimizer : {args.optimizer}',
        f'    Learning  : {args.lr}',
        f'    Global Rounds   : {args.epochs}\n',
        '    Federated parameters:',
        f'    {distribution}',
        f'    Fraction of users  : {args.frac}',
        f'    Local Batch size   : {args.local_bs}',
        f'    Local Epochs       : {args.local_ep}\n',
    )
    for line in report:
        print(line)
    return
